package com.xiam.consia.ml_new.attributeselection;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import com.google.common.collect.Ordering;
import com.google.common.collect.Sets;
import com.xiam.consia.featurecapture.store.FeatureSample;
import com.xiam.consia.featurecapture.store.FeatureSampleStore;
import com.xiam.consia.featurecapture.store.attributes.Attribute;
import com.xiam.consia.featurecapture.store.attributes.AttributeStore;
import com.xiam.consia.ml_new.tree.SplitInfoBuilder;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/**
 * Attribute selection strategy that chooses the split attribute (and, for numeric attributes, the
 * split threshold) with the highest information gain.
 *
 * <p>Information gain is the entropy of the class distribution minus the value returned by
 * {@code AttributeCounter.calcEntropyConditionedOnAttribute}. Entropies in this class use the
 * natural logarithm ({@link Math#log}).
 *
 * <p>NOTE(review): the decompiler reported that this class and several members were originally
 * package-private; their visibility has been left as found.
 */
public class InformationGain extends AttributeSelection {

    /**
     * Sentinel gain meaning "no usable split". {@link Double#MIN_VALUE} is the smallest
     * <em>positive</em> double (not the most negative one), so a real split must strictly exceed
     * this value.
     */
    private static final double NO_GAIN = Double.MIN_VALUE;

    /** Split point recorded for nominal attributes and for "no attribute selected". */
    private static final double UNUSED_SPLIT_POINT = Double.MIN_VALUE;

    /** Split point recorded for a numeric attribute when no threshold beats {@link #NO_GAIN}. */
    private static final double NO_NUMERIC_SPLIT_POINT = -1.0d;

    private final AttributeStore attributeStore;

    /**
     * Creates an information-gain attribute selector.
     *
     * @param attributeStore provides the candidate attributes and their count
     * @param doRandomisation forwarded to {@code AttributeSelection}; presumably backs
     *     {@code isDoRandomisation()} — confirm against the superclass
     * @param numRandomAtts forwarded to {@code AttributeSelection}; presumably backs
     *     {@code getNumRandomAttsForSplitCriteria()} — confirm against the superclass
     */
    public InformationGain(AttributeStore attributeStore, boolean doRandomisation, int numRandomAtts) {
        super(doRandomisation, numRandomAtts);
        this.attributeStore = attributeStore;
    }

    /**
     * Narrows a {@code long} to a boxed {@link Short} using a primitive cast, so out-of-range
     * values wrap rather than fail.
     *
     * @param value the value to narrow
     * @return {@code (short) value}, boxed
     */
    public static Short asShort(long value) {
        return Short.valueOf((short) value);
    }

    /**
     * Computes {@code -sum(p * ln p)} over the given counts, where {@code p = count / total}.
     * Zero (and negative) probabilities are skipped.
     *
     * @param classCounts per-class counts
     * @param total normaliser for the counts
     * @return the entropy of the normalised counts
     */
    private static double calcClassEntropy(Collection<Short> classCounts, double total) {
        double sum = 0.0d;
        for (Short count : classCounts) {
            double p = count.doubleValue() / total;
            if (p > 0.0d) {
                sum += p * Math.log(p);
            }
        }
        return -sum;
    }

    /**
     * Computes {@code -sum(p * ln p)}, with {@code p = counts.get(key) / total}, over the keys in
     * {@code keys} that are present in {@code counts}.
     *
     * @param counts counts keyed by name
     * @param total normaliser for the counts
     * @param keys keys to include in the sum; keys absent from {@code counts} are ignored
     * @return the entropy term, or {@code -0.0} when {@code total <= 0}
     */
    public static double calcY(Map<String, Short> counts, double total, Set<String> keys) {
        if (total <= 0.0d) {
            return -0.0d;
        }
        double sum = 0.0d;
        for (String key : keys) {
            if (counts.containsKey(key)) {
                double p = counts.get(key).doubleValue() / total;
                if (p > 0.0d) {
                    sum += p * Math.log(p);
                }
            }
        }
        return -sum;
    }

    /**
     * Computes information gain: the class entropy minus the entropy conditioned on the
     * attribute represented by {@code attributeCounter}.
     *
     * @param classCounts per-class sample counts; its key set is passed to the counter
     * @param attributeCounter counter holding the attribute's per-class statistics
     * @param numRecords normaliser used for both entropy terms
     * @return the information gain
     */
    public static double calculateInformationGain(Map<String, Short> classCounts, AttributeCounter attributeCounter, int numRecords) {
        return calcClassEntropy(classCounts.values(), numRecords)
                - attributeCounter.calcEntropyConditionedOnAttribute(classCounts.keySet(), numRecords);
    }

    /**
     * Returns the attribute to evaluate in iteration {@code index}.
     *
     * <p>Without randomisation this is simply {@code attributes.get(index)}. With randomisation,
     * attributes are drawn at random until one not yet in {@code alreadyChosen} is found, and
     * that attribute is recorded there. The caller must request fewer attributes than the list
     * holds, otherwise this loop cannot terminate.
     */
    private Attribute getAttribute(int index, ImmutableList<Attribute> attributes, Set<Attribute> alreadyChosen) {
        if (!isDoRandomisation()) {
            return attributes.get(index);
        }
        Attribute candidate = pickRandomAttributeName(attributes);
        // Set.add returns false for an attribute that was already chosen, so keep drawing.
        while (!alreadyChosen.add(candidate)) {
            candidate = pickRandomAttributeName(attributes);
        }
        return candidate;
    }

    /** Number of attributes requested per split, before capping to the attributes available. */
    private int getNumAttributesToEvaluate() {
        return isDoRandomisation() ? getNumRandomAttsForSplitCriteria() : this.attributeStore.getNumAttributes();
    }

    /** Draws a uniformly random attribute from {@code attributes}, which must be non-empty. */
    private Attribute pickRandomAttributeName(ImmutableList<Attribute> attributes) {
        // Index by the list's own size so the draw can never fall outside it.
        return attributes.get(random.nextInt(attributes.size()));
    }

    /**
     * Finds the best threshold for a numeric attribute and records its gain and split point.
     *
     * <p>Records are sorted by the attribute value. Between each pair of adjacent distinct values,
     * the gain is evaluated with the counter updated only for the records that precede the
     * current one. The midpoint of the best pair is recorded as the split point. If no threshold
     * beats {@link #NO_GAIN}, or there are no records, the gain stays {@link #NO_GAIN} and the
     * split point stays {@link #NO_NUMERIC_SPLIT_POINT}.
     */
    private void setGainAndSplitInfoForContinuousAttribute(Iterable<FeatureSample> records, int numRecords, String attributeName, Map<String, Short> classCounts, Map<String, Double> gains, Map<String, Double> splitPoints) {
        ImmutableList<FeatureSample> sorted = sortRecordsByAttribute(records, attributeName);
        gains.put(attributeName, Double.valueOf(NO_GAIN));
        splitPoints.put(attributeName, Double.valueOf(NO_NUMERIC_SPLIT_POINT));
        if (sorted.isEmpty()) {
            // Nothing to split; previously this threw IndexOutOfBoundsException on get(0).
            return;
        }
        ContinuousAttributeCounter counter = ContinuousAttributeCounter.create(numRecords, classCounts);
        int size = sorted.size();
        double bestGain = NO_GAIN;
        double bestSplitPoint = NO_NUMERIC_SPLIT_POINT;
        double previousValue = sorted.get(0).getAttributeByName(attributeName).getFloatValue();
        int processed = 1;
        for (FeatureSample sample : sorted) {
            double value = sample.getAttributeByName(attributeName).getFloatValue();
            // A threshold can only sit between two distinct values.
            if (value > previousValue) {
                double gain = calculateGain(classCounts, counter, size);
                if (gain > bestGain) {
                    bestGain = gain;
                    bestSplitPoint = (previousValue + value) / 2.0d;
                }
            }
            counter.updateValueCounts(size, processed);
            counter.updateAttributeCountsPerClassMaps(sample.getResultClass());
            previousValue = value;
            processed++;
        }
        gains.put(attributeName, Double.valueOf(bestGain));
        splitPoints.put(attributeName, Double.valueOf(bestSplitPoint));
    }

    /**
     * Records the gain (and split point) for one attribute in {@code gains} / {@code splitPoints}.
     * Nominal attributes get {@link #UNUSED_SPLIT_POINT} as their split point.
     */
    private void setGainForAttribute(Iterable<FeatureSample> records, int numRecords, boolean isNominal, String attributeName, Map<String, Short> classCounts, Map<String, Double> gains, Map<String, Double> splitPoints) {
        if (isNominal) {
            setGainForDiscreteAttribute(records, numRecords, attributeName, classCounts, gains);
            splitPoints.put(attributeName, Double.valueOf(UNUSED_SPLIT_POINT));
        } else {
            setGainAndSplitInfoForContinuousAttribute(records, numRecords, attributeName, classCounts, gains, splitPoints);
        }
    }

    /** Records in {@code gains} the gain of a nominal attribute over all {@code records}. */
    private void setGainForDiscreteAttribute(Iterable<FeatureSample> records, int numRecords, String attributeName, Map<String, Short> classCounts, Map<String, Double> gains) {
        AttributeCounter counter = DiscreteAttributeCounter.createDiscrete(records, attributeName);
        gains.put(attributeName, Double.valueOf(calculateGain(classCounts, counter, numRecords)));
    }

    /** Returns an immutable copy of {@code records} sorted by ascending value of {@code attributeName}. */
    private static ImmutableList<FeatureSample> sortRecordsByAttribute(Iterable<FeatureSample> records, final String attributeName) {
        return new Ordering<FeatureSample>() {
            @Override
            public int compare(FeatureSample left, FeatureSample right) {
                return Double.compare(
                        left.getAttributeByName(attributeName).getFloatValue(),
                        right.getAttributeByName(attributeName).getFloatValue());
            }
        }.immutableSortedCopy(records);
    }

    /**
     * Gain function used to score attributes; subclasses may override it. This implementation
     * delegates to {@link #calculateInformationGain}.
     */
    protected double calculateGain(Map<String, Short> classCounts, AttributeCounter attributeCounter, int numRecords) {
        return calculateInformationGain(classCounts, attributeCounter, numRecords);
    }

    @Override
    public SplitInfoBuilder findOptimumSplit(FeatureSampleStore featureSampleStore, Map<String, Short> classCounts) {
        return findOptimumSplit(featureSampleStore, classCounts, Maps.<String, Double>newHashMapWithExpectedSize(this.attributeStore.getNumAttributes()));
    }

    /**
     * Evaluates the candidate attributes and returns the one with the highest positive gain.
     *
     * @param featureSampleStore records to split
     * @param classCounts per-class sample counts
     * @param gains output map, filled with the gain of every evaluated attribute (exposed for
     *     tests)
     * @return builder for the best split. The final argument is {@code true} only when some
     *     attribute beat {@link #NO_GAIN}. Previously the best gain was initialised to the
     *     positive {@code Double.MIN_VALUE}, which made that flag unconditionally {@code true},
     *     even for the empty "no attribute" result.
     */
    @VisibleForTesting
    SplitInfoBuilder findOptimumSplit(FeatureSampleStore featureSampleStore, Map<String, Short> classCounts, Map<String, Double> gains) {
        ImmutableList<Attribute> attributes = this.attributeStore.getAttributes();
        // Never evaluate more attributes than exist: random selection without replacement would
        // otherwise loop forever, and deterministic selection would index past the end.
        int numToEvaluate = Math.min(getNumAttributesToEvaluate(), attributes.size());
        Map<String, Double> splitPoints = Maps.newHashMapWithExpectedSize(this.attributeStore.getNumAttributes());
        Set<Attribute> chosen = Sets.newHashSet();

        String bestName = "";
        double bestGain = NO_GAIN;
        double bestSplitPoint = UNUSED_SPLIT_POINT;
        boolean bestIsNominal = false;

        for (int i = 0; i < numToEvaluate; i++) {
            Attribute attribute = getAttribute(i, attributes, chosen);
            String name = attribute.getName();
            boolean isNominal = !attribute.isNum();
            setGainForAttribute(featureSampleStore.getFeatureRecords(), featureSampleStore.getNumRecords(), isNominal, name, classCounts, gains, splitPoints);
            double gain = gains.get(name).doubleValue();
            // Positive comparisons reject NaN gains, which the old "<=" test let through.
            if (gain > 0.0d && gain > bestGain) {
                bestName = name;
                bestGain = gain;
                bestSplitPoint = splitPoints.get(name).doubleValue();
                bestIsNominal = isNominal;
            }
        }
        return SplitInfoBuilder.build(bestName, bestSplitPoint, bestIsNominal, bestGain > NO_GAIN);
    }
}
